package io.indexr.vlt.segment.index;

import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

import io.indexr.data.Cleanable;
import io.indexr.data.DictStruct;
import io.indexr.data.Freeable;
import io.indexr.data.OffheapBytes;
import io.indexr.io.ByteBufferReader;
import io.indexr.io.ByteBufferWriter;
import io.indexr.segment.ColumnType;
import io.indexr.util.ByteBufferUtil;
import io.indexr.util.BytesUtil;
import io.indexr.util.DirectBitMap;
import io.indexr.util.IOUtil;
import io.indexr.util.MemoryUtil;
import io.indexr.util.Try;

/**
 * Merges the dictionaries of several {@link DictStruct}s into one sorted, de-duplicated dictionary
 * written to temporary files in a private directory.
 *
 * <p>For every distinct entry, the entry value is appended to the entry file and a bitmap with one
 * bit per input struct (bit {@code i} set when {@code structs[i]} contains the entry) is appended to
 * the bitmap file. Numeric entries are written as fixed-width values; for {@link ColumnType#STRING}
 * the entry file holds the concatenated string bytes and a separate offset file holds {@code n + 1}
 * int offsets into it.
 *
 * <p>Each input dictionary is assumed to be sorted ascending (this is asserted while merging). The
 * class does no synchronization. Call {@link #clean()} to release the direct buffers and delete the
 * generated files.
 */
public class DictMerge implements Freeable, Cleanable {
    private static final Logger logger = LoggerFactory.getLogger(DictMerge.class);
    private static final int BUFFER_SIZE = 1 << 20;
    private static final int BITMAP_MERGE_SLOT = 32;

    private final Path dir;
    private Path entryFilePath;
    private Path entryOffsetFilePath; // For strings
    private Path bitmapFilePath;
    private FileChannel entryFile;
    private FileChannel entryOffsetFile; // For strings
    private FileChannel bitmapFile;

    private int entryCount; // i.e. the bitmap count
    private int bitmapMergeSlotCount;

    private long entryFileSize;
    private long entryOffsetFileSize;
    private long bitmapFileSize;

    // Staging buffers, flushed to the matching file channel whenever they fill up.
    private ByteBuffer entryBuffer = ByteBufferUtil.allocateDirect(BUFFER_SIZE);
    private ByteBuffer entryOffsetBuffer = ByteBufferUtil.allocateDirect(BUFFER_SIZE); // For strings
    private ByteBuffer bitmapBuffer = ByteBufferUtil.allocateDirect(BUFFER_SIZE);

    private int strEntryLastOffset; // For strings
    private final long entryBufferAddr = MemoryUtil.getAddress(entryBuffer);

    /**
     * Creates a merger with its own temporary working directory.
     *
     * @throws IOException if the temporary directory cannot be created
     */
    public DictMerge() throws IOException {
        dir = Files.createTempDirectory("indexr_dictmerge");
    }

    /** Releases the direct staging buffers. The instance cannot merge afterwards. */
    @Override
    public void free() {
        if (entryBuffer != null) {
            ByteBufferUtil.free(entryBuffer);
            entryBuffer = null;
        }
        if (entryOffsetBuffer != null) {
            ByteBufferUtil.free(entryOffsetBuffer);
            entryOffsetBuffer = null;
        }
        if (bitmapBuffer != null) {
            ByteBufferUtil.free(bitmapBuffer);
            bitmapBuffer = null;
        }
    }

    /**
     * Delete all files generated by this class.
     */
    @Override
    public void clean() throws IOException {
        free();
        FileUtils.deleteDirectory(dir.toFile());
    }

    // @formatter:off
    public Path entryFilePath() {return entryFilePath;}
    public Path entryOffsetFilePath() {return entryOffsetFilePath;}
    public Path bitmapFilePath() {return bitmapFilePath;}
    public int entryCount() {return entryCount;}
    public int bitmapMergeSlotCount() {return bitmapMergeSlotCount;}
    public long entryFileSize() {return entryFileSize;}
    public long entryOffsetFileSize() {return entryOffsetFileSize;}
    public long bitmapFileSize() {return bitmapFileSize;}
    // @formatter:on

    /**
     * Merges the dictionaries of {@code structs} into the entry / entry-offset / bitmap files.
     *
     * <p>Overwrites any files produced by a previous call. When at least
     * {@code 2 * BITMAP_MERGE_SLOT} entries are produced, merged bitmaps built by
     * {@code MergeBitMapUtil.mergeBitMaps} are appended to the bitmap file.
     *
     * @param dataType one of {@link ColumnType#INT}, {@link ColumnType#LONG}, {@link ColumnType#FLOAT},
     *                 {@link ColumnType#DOUBLE}, {@link ColumnType#STRING}
     * @param structs  the dictionaries to merge; each is assumed sorted ascending
     * @throws IOException           if writing any of the files fails
     * @throws IllegalStateException if {@code dataType} is unsupported or this instance was freed
     */
    public void merge(byte dataType, DictStruct[] structs) throws IOException {
        if (entryBuffer == null) {
            throw new IllegalStateException("DictMerge has already been freed");
        }
        // Reset all state a previous merge may have left behind.
        entryBuffer.clear();
        entryOffsetBuffer.clear();
        bitmapBuffer.clear();
        strEntryLastOffset = 0;
        entryCount = 0;
        bitmapMergeSlotCount = 0;

        DirectBitMap bitmap = new DirectBitMap(structs.length);

        try {
            entryFilePath = dir.resolve("entry");
            entryFile = FileChannel.open(
                    entryFilePath,
                    StandardOpenOption.WRITE,
                    StandardOpenOption.CREATE,
                    StandardOpenOption.TRUNCATE_EXISTING);

            // READ is needed because merged bitmaps are computed from this file's own content.
            bitmapFilePath = dir.resolve("bitmap");
            bitmapFile = FileChannel.open(
                    bitmapFilePath,
                    StandardOpenOption.WRITE,
                    StandardOpenOption.READ,
                    StandardOpenOption.CREATE,
                    StandardOpenOption.TRUNCATE_EXISTING);

            if (dataType != ColumnType.STRING) {
                entryOffsetFilePath = null;
                entryOffsetFile = null;
            } else {
                entryOffsetFilePath = dir.resolve("entry_offset");
                entryOffsetFile = FileChannel.open(
                        entryOffsetFilePath,
                        StandardOpenOption.WRITE,
                        StandardOpenOption.CREATE,
                        StandardOpenOption.TRUNCATE_EXISTING);
            }

            switch (dataType) {
                case ColumnType.INT:
                    mergeInts(structs, bitmap);
                    break;
                case ColumnType.LONG:
                    mergeLongs(structs, bitmap);
                    break;
                case ColumnType.FLOAT:
                    mergeFloats(structs, bitmap);
                    break;
                case ColumnType.DOUBLE:
                    mergeDoubles(structs, bitmap);
                    break;
                case ColumnType.STRING:
                    mergeStrings(structs, bitmap);
                    break;
                default:
                    throw new IllegalStateException("Illegal dataType: " + dataType);
            }

            if (entryCount >= BITMAP_MERGE_SLOT << 1) {
                // Store merged bitmaps after the per-entry bitmaps.
                bitmapMergeSlotCount = BITMAP_MERGE_SLOT;

                ByteBufferReader bitmapFileReader = ByteBufferReader.of(bitmapFile, 0, null);
                ByteBufferWriter mergeBitmapWriter = ByteBufferWriter.of(bitmapFile, bitmapFile.size(), null);

                MergeBitMapUtil.mergeBitMaps(
                        bitmapFileReader,
                        mergeBitmapWriter,
                        structs.length,
                        entryCount,
                        bitmapMergeSlotCount);
            } else {
                bitmapMergeSlotCount = 0;
            }

            entryFileSize = entryFile.size();
            entryOffsetFileSize = entryOffsetFile == null ? 0 : entryOffsetFile.size();
            bitmapFileSize = bitmapFile.size();
        } finally {
            forceAndClose(entryFile);
            entryFile = null;
            forceAndClose(entryOffsetFile);
            entryOffsetFile = null;
            forceAndClose(bitmapFile);
            bitmapFile = null;

            bitmap.free();
        }
    }

    /** K-way merges int dictionaries (4 bytes per entry) into the entry and bitmap files. */
    private void mergeInts(DictStruct[] structs, DirectBitMap bitmap) throws IOException {
        int structSize = structs.length;
        int[] dictEntryCounts = new int[structSize];
        long[] dictEntryAddrs = new long[structSize];
        for (int i = 0; i < structSize; i++) {
            dictEntryCounts[i] = structs[i].dictEntryCount();
            dictEntryAddrs[i] = structs[i].dictEntriesAddr();
        }

        ByteBuffer bitmapMemory = (ByteBuffer) bitmap.attach;
        int[] dictOffsets = new int[structSize];

        // If no struct has any entry, this placeholder is still written once (unchanged output).
        int curEntry = Integer.MAX_VALUE;
        boolean hasCur = false;
        while (true) {
            int minEntry = Integer.MAX_VALUE;
            int structId = -1;
            for (int i = 0; i < structSize; i++) {
                int offset = dictOffsets[i];
                if (offset >= dictEntryCounts[i]) {
                    continue;
                }
                // Widen before shifting so huge dictionaries cannot overflow the byte offset.
                int e = MemoryUtil.getInt(dictEntryAddrs[i] + ((long) offset << 2));
                if (e <= minEntry) {
                    minEntry = e;
                    structId = i;
                }
            }
            if (structId == -1) {
                break;
            }

            // Forward offset
            dictOffsets[structId]++;

            assert !hasCur || minEntry >= curEntry : String.format("%s, %s", minEntry, curEntry);

            if (!hasCur || minEntry <= curEntry) {
                bitmap.set(structId);
            } else {
                saveEntryInt(curEntry, false);
                saveBitmap(bitmapMemory, false);

                bitmap.clear();
                bitmap.set(structId);
            }
            curEntry = minEntry;
            hasCur = true;
        }
        saveEntryInt(curEntry, true);
        saveBitmap(bitmapMemory, true);
    }

    /** K-way merges long dictionaries (8 bytes per entry) into the entry and bitmap files. */
    private void mergeLongs(DictStruct[] structs, DirectBitMap bitmap) throws IOException {
        int structSize = structs.length;
        int[] dictEntryCounts = new int[structSize];
        long[] dictEntryAddrs = new long[structSize];
        for (int i = 0; i < structSize; i++) {
            dictEntryCounts[i] = structs[i].dictEntryCount();
            dictEntryAddrs[i] = structs[i].dictEntriesAddr();
        }

        ByteBuffer bitmapMemory = (ByteBuffer) bitmap.attach;
        int[] dictOffsets = new int[structSize];

        // If no struct has any entry, this placeholder is still written once (unchanged output).
        long curEntry = Long.MAX_VALUE;
        boolean hasCur = false;
        while (true) {
            long minEntry = Long.MAX_VALUE;
            int structId = -1;
            for (int i = 0; i < structSize; i++) {
                int offset = dictOffsets[i];
                if (offset >= dictEntryCounts[i]) {
                    continue;
                }
                long e = MemoryUtil.getLong(dictEntryAddrs[i] + ((long) offset << 3));
                if (e <= minEntry) {
                    minEntry = e;
                    structId = i;
                }
            }
            if (structId == -1) {
                break;
            }

            // Forward offset
            dictOffsets[structId]++;

            assert !hasCur || minEntry >= curEntry : String.format("%s, %s", minEntry, curEntry);

            if (!hasCur || minEntry <= curEntry) {
                bitmap.set(structId);
            } else {
                saveEntryLong(curEntry, false);
                saveBitmap(bitmapMemory, false);

                bitmap.clear();
                bitmap.set(structId);
            }
            curEntry = minEntry;
            hasCur = true;
        }
        saveEntryLong(curEntry, true);
        saveBitmap(bitmapMemory, true);
    }

    /**
     * K-way merges float dictionaries (4 bytes per entry) into the entry and bitmap files.
     *
     * <p>NOTE: NaN never satisfies {@code <=}, so a NaN entry is never selected; it and anything
     * after it in the same dictionary are skipped (same as before this revision).
     */
    private void mergeFloats(DictStruct[] structs, DirectBitMap bitmap) throws IOException {
        int structSize = structs.length;
        int[] dictEntryCounts = new int[structSize];
        long[] dictEntryAddrs = new long[structSize];
        for (int i = 0; i < structSize; i++) {
            dictEntryCounts[i] = structs[i].dictEntryCount();
            dictEntryAddrs[i] = structs[i].dictEntriesAddr();
        }

        ByteBuffer bitmapMemory = (ByteBuffer) bitmap.attach;
        int[] dictOffsets = new int[structSize];

        // If no struct has any entry, this placeholder is still written once (unchanged output).
        float curEntry = Float.MAX_VALUE;
        boolean hasCur = false;
        while (true) {
            // POSITIVE_INFINITY (not MAX_VALUE) so that +Infinity entries can still be selected.
            float minEntry = Float.POSITIVE_INFINITY;
            int structId = -1;
            for (int i = 0; i < structSize; i++) {
                int offset = dictOffsets[i];
                if (offset >= dictEntryCounts[i]) {
                    continue;
                }
                float e = MemoryUtil.getFloat(dictEntryAddrs[i] + ((long) offset << 2));
                if (e <= minEntry) {
                    minEntry = e;
                    structId = i;
                }
            }
            if (structId == -1) {
                break;
            }

            // Forward offset
            dictOffsets[structId]++;

            assert !hasCur || minEntry >= curEntry : String.format("%s, %s", minEntry, curEntry);

            if (!hasCur || minEntry <= curEntry) {
                bitmap.set(structId);
            } else {
                saveEntryFloat(curEntry, false);
                saveBitmap(bitmapMemory, false);

                bitmap.clear();
                bitmap.set(structId);
            }
            curEntry = minEntry;
            hasCur = true;
        }
        saveEntryFloat(curEntry, true);
        saveBitmap(bitmapMemory, true);
    }

    /**
     * K-way merges double dictionaries (8 bytes per entry) into the entry and bitmap files.
     *
     * <p>NOTE: NaN never satisfies {@code <=}, so a NaN entry is never selected; it and anything
     * after it in the same dictionary are skipped (same as before this revision).
     */
    private void mergeDoubles(DictStruct[] structs, DirectBitMap bitmap) throws IOException {
        int structSize = structs.length;
        int[] dictEntryCounts = new int[structSize];
        long[] dictEntryAddrs = new long[structSize];
        for (int i = 0; i < structSize; i++) {
            dictEntryCounts[i] = structs[i].dictEntryCount();
            dictEntryAddrs[i] = structs[i].dictEntriesAddr();
        }

        ByteBuffer bitmapMemory = (ByteBuffer) bitmap.attach;
        int[] dictOffsets = new int[structSize];

        // If no struct has any entry, this placeholder is still written once (unchanged output).
        double curEntry = Double.MAX_VALUE;
        boolean hasCur = false;
        while (true) {
            // POSITIVE_INFINITY (not MAX_VALUE) so that +Infinity entries can still be selected.
            double minEntry = Double.POSITIVE_INFINITY;
            int structId = -1;
            for (int i = 0; i < structSize; i++) {
                int offset = dictOffsets[i];
                if (offset >= dictEntryCounts[i]) {
                    continue;
                }
                double e = MemoryUtil.getDouble(dictEntryAddrs[i] + ((long) offset << 3));
                if (e <= minEntry) {
                    minEntry = e;
                    structId = i;
                }
            }
            if (structId == -1) {
                break;
            }

            // Forward offset
            dictOffsets[structId]++;

            assert !hasCur || minEntry >= curEntry : String.format("%s, %s", minEntry, curEntry);

            if (!hasCur || minEntry <= curEntry) {
                bitmap.set(structId);
            } else {
                saveEntryDouble(curEntry, false);
                saveBitmap(bitmapMemory, false);

                bitmap.clear();
                bitmap.set(structId);
            }
            curEntry = minEntry;
            hasCur = true;
        }
        saveEntryDouble(curEntry, true);
        saveBitmap(bitmapMemory, true);
    }

    /** K-way merges string dictionaries, ordered by {@code BytesUtil.compareBytes}. */
    private void mergeStrings(DictStruct[] structs, DirectBitMap bitmap) throws IOException {
        int structSize = structs.length;
        int[] dictEntryCounts = new int[structSize];
        for (int i = 0; i < structSize; i++) {
            dictEntryCounts[i] = structs[i].dictEntryCount();
        }

        ByteBuffer bitmapMemory = (ByteBuffer) bitmap.attach;
        int[] dictOffsets = new int[structSize];

        // If no struct has any entry, this empty placeholder is still written once (unchanged output).
        OffheapBytes curEntry = new OffheapBytes();
        boolean hasCur = false;

        OffheapBytes e = new OffheapBytes();
        OffheapBytes minEntry = new OffheapBytes();

        while (true) {
            int structId = -1;
            for (int i = 0; i < structSize; i++) {
                int offset = dictOffsets[i];
                if (offset >= dictEntryCounts[i]) {
                    continue;
                }
                structs[i].stringDictEntries().getString(offset, e);
                // "No candidate yet" is tracked by structId rather than by a zero address.
                if (structId == -1
                        || BytesUtil.compareBytes(e.addr(), e.len(), minEntry.addr(), minEntry.len()) <= 0) {
                    minEntry.set(e);
                    structId = i;
                }
            }
            if (structId == -1) {
                break;
            }

            // Forward offset
            dictOffsets[structId]++;

            assert !hasCur
                    || BytesUtil.compareBytes(minEntry.addr(), minEntry.len(), curEntry.addr(), curEntry.len()) >= 0;

            if (!hasCur
                    || BytesUtil.compareBytes(minEntry.addr(), minEntry.len(), curEntry.addr(), curEntry.len()) == 0) {
                bitmap.set(structId);
            } else {
                saveEntryString(curEntry, false);
                saveBitmap(bitmapMemory, false);

                bitmap.clear();
                bitmap.set(structId);
            }
            curEntry.set(minEntry);
            hasCur = true;
        }
        saveEntryString(curEntry, true);
        saveBitmap(bitmapMemory, true);
    }

    /**
     * Only for test: writes the string offset file followed by the string entry file into a new
     * file named {@code str} in the working directory.
     *
     * @return the path of the combined file
     * @throws IOException           if reading or writing fails
     * @throws IllegalStateException if the last merge did not produce string offsets
     */
    public Path combineStringOffsetAndEntryFile() throws IOException {
        if (entryOffsetFilePath == null) {
            throw new IllegalStateException("No string entry offset file; last merge was not a STRING merge");
        }
        Path path = dir.resolve("str");
        try (FileChannel entryIn = FileChannel.open(entryFilePath, StandardOpenOption.READ);
             FileChannel offsetIn = FileChannel.open(entryOffsetFilePath, StandardOpenOption.READ);
             FileChannel out = FileChannel.open(
                     path,
                     StandardOpenOption.WRITE,
                     StandardOpenOption.CREATE,
                     StandardOpenOption.TRUNCATE_EXISTING)) {
            ByteBuffer entryFileBuffer = entryIn.map(FileChannel.MapMode.READ_ONLY, 0, entryIn.size());
            ByteBuffer entryOffsetFileBuffer = offsetIn.map(FileChannel.MapMode.READ_ONLY, 0, offsetIn.size());

            writeFully(out, entryOffsetFileBuffer);
            writeFully(out, entryFileBuffer);
            Try.on(() -> out.force(false), logger);
        }
        return path;
    }

    /** Appends the current bitmap to the bitmap file, optionally flushing the staging buffer. */
    private void saveBitmap(ByteBuffer bitmapMemory, boolean isFlush) throws IOException {
        bitmapMemory.clear();
        if (bitmapBuffer.remaining() < bitmapMemory.remaining()) {
            flush(bitmapFile, bitmapBuffer);
        }
        if (bitmapBuffer.remaining() >= bitmapMemory.remaining()) {
            bitmapBuffer.put(bitmapMemory);
        } else {
            // Bitmap is larger than the whole (now empty) staging buffer: write it straight out.
            writeFully(bitmapFile, bitmapMemory);
        }
        if (isFlush) {
            flush(bitmapFile, bitmapBuffer);
        }
    }

    private void saveEntryInt(int curEntry, boolean isFlush) throws IOException {
        entryCount++;
        if (entryBuffer.remaining() < Integer.BYTES) {
            flush(entryFile, entryBuffer);
        }
        entryBuffer.putInt(curEntry);
        if (isFlush) {
            flush(entryFile, entryBuffer);
        }
    }

    private void saveEntryLong(long curEntry, boolean isFlush) throws IOException {
        entryCount++;
        if (entryBuffer.remaining() < Long.BYTES) {
            flush(entryFile, entryBuffer);
        }
        entryBuffer.putLong(curEntry);
        if (isFlush) {
            flush(entryFile, entryBuffer);
        }
    }

    private void saveEntryFloat(float curEntry, boolean isFlush) throws IOException {
        entryCount++;
        if (entryBuffer.remaining() < Float.BYTES) {
            flush(entryFile, entryBuffer);
        }
        entryBuffer.putFloat(curEntry);
        if (isFlush) {
            flush(entryFile, entryBuffer);
        }
    }

    private void saveEntryDouble(double curEntry, boolean isFlush) throws IOException {
        entryCount++;
        if (entryBuffer.remaining() < Double.BYTES) {
            flush(entryFile, entryBuffer);
        }
        entryBuffer.putDouble(curEntry);
        if (isFlush) {
            flush(entryFile, entryBuffer);
        }
    }

    /**
     * Appends one string entry: its start offset goes to the offset file and its bytes to the entry
     * file. On flush the final end offset is appended, so n entries produce n + 1 offsets.
     */
    private void saveEntryString(OffheapBytes curEntry, boolean isFlush) throws IOException {
        int len = curEntry.len();
        // Offsets are stored as ints; fail loudly instead of silently wrapping past 2 GiB.
        if (len > Integer.MAX_VALUE - strEntryLastOffset) {
            throw new IllegalStateException("String dictionary too large: "
                    + ((long) strEntryLastOffset + len) + " bytes exceed the int offset range");
        }
        entryCount++;

        if (entryOffsetBuffer.remaining() < Integer.BYTES) {
            flush(entryOffsetFile, entryOffsetBuffer);
        }
        entryOffsetBuffer.putInt(strEntryLastOffset);
        strEntryLastOffset += len;

        // Copy in chunks so entries longer than the staging buffer never overrun its memory.
        long src = curEntry.addr();
        int left = len;
        while (left > 0) {
            if (!entryBuffer.hasRemaining()) {
                flush(entryFile, entryBuffer);
            }
            int n = Math.min(left, entryBuffer.remaining());
            int pos = entryBuffer.position();
            MemoryUtil.copyMemory(src, entryBufferAddr + pos, n);
            entryBuffer.position(pos + n);
            src += n;
            left -= n;
        }

        if (isFlush) {
            // n string entries got n+1 offsets
            if (entryOffsetBuffer.remaining() < Integer.BYTES) {
                flush(entryOffsetFile, entryOffsetBuffer);
            }
            entryOffsetBuffer.putInt(strEntryLastOffset);
            flush(entryOffsetFile, entryOffsetBuffer);

            flush(entryFile, entryBuffer);
        }
    }

    /** Writes the buffer's pending bytes to the channel and resets the buffer for reuse. */
    private static void flush(FileChannel channel, ByteBuffer buffer) throws IOException {
        buffer.flip();
        writeFully(channel, buffer);
        buffer.clear();
    }

    /** Writes every remaining byte of {@code buffer}; a single write call may be partial. */
    private static void writeFully(FileChannel channel, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }

    /** Best-effort force-to-disk then quiet close; does nothing for {@code null}. */
    private static void forceAndClose(FileChannel channel) {
        if (channel != null) {
            Try.on(() -> channel.force(false), logger);
            IOUtil.closeQuietly(channel);
        }
    }
}
